import os

from matplotlib import pyplot as plt
from pathlib import Path
from skimage import io

import img_util.io
import img_util.util


def _show_figure_blocking(fig):
    """Lay out, maximise (when the backend supports it), display and then close ``fig``."""
    fig.tight_layout()
    # ``window.showMaximized`` only exists on Qt-based backends; on any other backend the
    # figure is simply shown at its requested size instead of raising AttributeError.
    manager = getattr(fig.canvas, 'manager', None)
    window = getattr(manager, 'window', None)
    show_maximized = getattr(window, 'showMaximized', None)
    if callable(show_maximized):
        show_maximized()
    plt.show()
    plt.close(fig)


def _plot_image_panels(title, panels, figsize):
    """Show ``(image, panel_title, imshow_kwargs)`` tuples side by side in one blocking figure."""
    fig, axes = plt.subplots(nrows=1, ncols=len(panels), figsize=figsize, num=title,
                             sharex='all', sharey='all')
    for ax, (img, panel_title, imshow_kwargs) in zip(axes, panels):
        ax.imshow(img, **imshow_kwargs)
        ax.set_title(panel_title)
    _show_figure_blocking(fig)


def plot_threshold_mask_vs_orig_img(mask_img_dir, img_basedir, cellpose_mask_dir=None,
                                    show_unthresh_mask=False, mask_file_ext='_mask.png'):
    """
    Show each thresholded mask image next to its original image for visual comparison.

    One figure is opened per mask image. With interactive mode switched off, ``plt.show()``
    blocks, so the next figure only appears after the current one is closed. Stop the Python
    console to break out early. Optionally, the non-thresholded cellpose mask is shown as an
    additional middle panel.

    Parameters
    ----------
    mask_img_dir: str
        Directory containing the thresholded mask images. Only files whose name contains
        ``mask_file_ext`` are used. Files are visited in sorted order.
    img_basedir: str
        Directory containing the original images. The matching original image is located via
        ``img_util.io.get_orig_img_path``.
    cellpose_mask_dir: str or None
        Directory (searched recursively) containing the non-thresholded cellpose masks
        (``.npy`` files). Required when ``show_unthresh_mask`` is True.
    show_unthresh_mask: bool
        If True, also plot the non-thresholded mask. Default=False
    mask_file_ext: str
        Substring identifying mask image files. Default='_mask.png'

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``show_unthresh_mask`` is True but ``cellpose_mask_dir`` is None, or if no ``.npy``
        mask matches a given mask image.
    """
    if show_unthresh_mask and cellpose_mask_dir is None:
        raise ValueError("cellpose_mask_dir must be given when show_unthresh_mask is True")

    # Turn off interactive mode so plt.show() blocks. Each figure can then be inspected, and the
    # loop advances only once the figure window has been closed.
    plt.ioff()
    figsize = (16, 12)

    # Collect the non-thresholded cellpose masks once, instead of re-walking cellpose_mask_dir
    # for every mask image. Each entry is (path, file name without extension).
    cpmasks = []
    if show_unthresh_mask:
        cpmasks = [(p, p.stem) for p in Path(cellpose_mask_dir).glob('**/*')
                   if p.is_file() and p.suffix == '.npy']

    for mask_img_path in sorted(Path(mask_img_dir).iterdir()):
        if not (mask_img_path.is_file() and mask_file_ext in mask_img_path.name):
            continue

        # Find file path of original image
        orig_img_path = img_util.io.get_orig_img_path(mask_img_path, img_basedir, '.png')
        fig_title = os.path.splitext(os.path.basename(orig_img_path))[0]

        matching_cpmasks = []
        if show_unthresh_mask:
            # NOTE(review): this is a substring match, so e.g. 'img1' also matches
            # 'img10_mask.png'. Every matching .npy file gets its own figure.
            matching_cpmasks = [p for p, stem in cpmasks if stem in mask_img_path.name]
            if not matching_cpmasks:
                raise ValueError("Failed to find original mask file (.npy) for " + os.path.basename(mask_img_path))

        orig_panel = (io.imread(orig_img_path), 'Original image', {'cmap': 'gray'})
        thresh_panel = (io.imread(str(mask_img_path)), 'Thresholded masks', {})

        if not show_unthresh_mask:
            _plot_image_panels(fig_title, [orig_panel, thresh_panel], figsize)
            continue

        for cpmask_path in matching_cpmasks:
            unthresh_mask_img = img_util.util.add_npmask_to_img_via_path(str(cpmask_path),
                                                                         orig_img_path)
            unthresh_panel = (unthresh_mask_img, 'Non-thresholded masks', {})
            _plot_image_panels(fig_title, [orig_panel, unthresh_panel, thresh_panel], figsize)
